//
// Created by Göksu Güvendiren on 2019-05-14.
//

#include "Scene.hpp"

void Scene::buildBVH() {
    printf(" - Generating BVH...\n\n");
    this->bvh = new BVHAccel(objects, 1, BVHAccel::SplitMethod::NAIVE);
}

// Finds the closest hit of `ray` against the scene by querying the BVH.
// buildBVH() must have been called first: `bvh` is dereferenced unchecked.
Intersection Scene::intersect(const Ray &ray) const
{
    auto *const accel = bvh;
    return accel->Intersect(ray);
}

// Picks one emissive object with probability proportional to its area and
// asks it to sample a point on itself.
//   pos - receives the sampled point (its `obj` is set to the chosen emitter).
//   pdf - receives the pdf reported by that emitter's Sample().
// If the scene has no emitters, neither output is written.
void Scene::sampleLight(Intersection &pos, float &pdf) const
{
    // Total area of all emitters, in scene order.
    float totalEmitArea = 0;
    for (auto *obj : objects)
        if (obj->hasEmit())
            totalEmitArea += obj->getArea();

    // A uniform position along the concatenated emitter areas.
    const float threshold = get_random_float() * totalEmitArea;

    // Walk the same running sum again; the first emitter whose cumulative
    // area reaches the threshold is the one sampled.
    float runningArea = 0;
    for (auto *obj : objects) {
        if (!obj->hasEmit())
            continue;
        runningArea += obj->getArea();
        if (threshold <= runningArea) {
            obj->Sample(pos, pdf);
            pos.obj = obj;
            return;
        }
    }
}

// Brute-force closest-hit search over `objects` (no acceleration structure).
//   ray       - the ray to test.
//   objects   - candidate objects, tested in order.
//   tNear     - in: current closest distance; out: updated if a closer hit is found.
//   index     - out: the per-object index reported by the closest hit's intersect().
//   hitObject - out: the closest object hit, or nullptr if none beat `tNear`.
// Returns true iff some object was hit closer than the incoming `tNear`.
bool Scene::trace(
        const Ray &ray,
        const std::vector<Object*> &objects,
        float &tNear, uint32_t &index, Object **hitObject)
{
    *hitObject = nullptr;
    for (Object *candidate : objects) {
        float tCandidate = kInfinity;
        // Initialised so `index` never receives an indeterminate value even if
        // an intersect() implementation reports a hit without writing it.
        uint32_t indexCandidate = 0;
        if (candidate->intersect(ray, tCandidate, indexCandidate) && tCandidate < tNear) {
            *hitObject = candidate;
            tNear = tCandidate;
            index = indexCandidate;
        }
    }

    return *hitObject != nullptr;
}

// Implementation of Path Tracing
// Estimates the radiance carried back along `ray` with a unidirectional path
// tracer: a direct term from one sampled light point (next-event estimation)
// plus an indirect term from one BSDF-sampled bounce, continued with Russian
// roulette.
//   ray   - the ray to shade.
//   depth - bounce counter; only forwarded to the recursive call (termination
//           is governed by RussianRoulette, not by a depth cap).
// Returns zero when `ray` hits nothing.
Vector3f Scene::castRay(const Ray &ray, int depth) const
{
    auto target = intersect(ray);
    if (!target.happened)
    {
        return Vector3f(0);
    }
    auto N = target.normal;
    auto p = target.coords;
    // NOTE(review): wo is the ray direction itself (pointing into the surface),
    // not its negation; eval/sample/pdf below are called with that convention —
    // confirm against Material.
    auto wo = ray.direction;

    // ---- Direct lighting: sample a point x on an emitter ----
    Intersection inter;
    // Zero-initialised: sampleLight() writes nothing when the scene has no
    // emitters, and the direct term is then skipped instead of reading garbage.
    float pdf_light = 0.0f;
    sampleLight(inter, pdf_light);

    auto x = inter.coords;
    auto to_light = x - p;
    auto ws = to_light.normalized();
    auto NN = -inter.normal;
    auto emit = inter.emit;
    float dist2 = dotProduct(to_light, to_light);  // |x - p|^2

    // Visibility: the ray p -> x must actually hit something, and that first hit
    // must be (within 0.01) the sampled light point; otherwise x is occluded.
    auto test_light_middle = intersect(Ray(p, ws));

    Vector3f L_dir = Vector3f();
    if (pdf_light > 0.0f && test_light_middle.happened &&
        (test_light_middle.coords - x).norm() < 0.01)
    {
        // L_dir = emit * eval(wo, ws, N) * cos(ws, N) * cos(ws, NN) / |x-p|^2 / pdf_light
        L_dir = emit * target.m->eval(wo, ws, N) * dotProduct(ws, N) * dotProduct(ws, NN) / dist2 / pdf_light;
    }

    // ---- Indirect lighting: one BSDF-sampled bounce, Russian roulette ----
    Vector3f L_indir = Vector3f();
    if (get_random_float() < RussianRoulette)
    {
        auto wi = target.m->sample(wo, N);
        auto test_ref_obj_inter = intersect(Ray(p, wi));
        // Only recurse onto non-emitting surfaces: light arriving straight from an
        // emitter is already accounted for by L_dir.
        if (test_ref_obj_inter.happened && !test_ref_obj_inter.obj->hasEmit())
        {
            float pdf = target.m->pdf(wo, wi, N);
            // A zero pdf would divide the estimate into inf/NaN; such a direction
            // contributes nothing.
            if (pdf > 0.0f)
            {
                // L_indir = shade(q, wi) * eval(wo, wi, N) * cos(wi, N) / pdf / RussianRoulette
                L_indir = castRay(Ray(p, wi), depth + 1) * target.m->eval(wo, wi, N) * dotProduct(wi, N) / pdf / RussianRoulette;
            }
        }
    }

    // Surfaces that emit also add their own emission to what they reflect.
    if (target.obj->hasEmit())
    {
        return L_dir + L_indir + target.m->getEmission();
    }

    return L_dir + L_indir;
}